{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import os\n",
    "from os.path import join as j_\n",
    "from PIL import Image\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# loading all packages here to start\n",
    "from uni import get_encoder\n",
    "from uni.downstream.extract_patch_features import extract_patch_features_from_dataloader\n",
    "from uni.downstream.eval_patch_features.linear_probe import eval_linear_probe\n",
    "from uni.downstream.eval_patch_features.fewshot import eval_knn, eval_fewshot\n",
    "from uni.downstream.eval_patch_features.protonet import ProtoNet, prototype_topk_vote\n",
    "from uni.downstream.eval_patch_features.metrics import get_eval_metrics, print_metrics\n",
    "from uni.downstream.utils import concat_images\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Downloading UNI weights + Creating Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Method 1: Following authentication (using ```huggingface_hub```), the ViT-H/14 model architecture with pretrained weights and image transforms for UNI can be directly loaded using the [timm](https://huggingface.co/docs/hub/en/timm) library. This method automatically downloads the model weights to the [huggingface_hub cache](https://huggingface.co/docs/huggingface_hub/en/guides/manage-cache) in your home directory (```~/.cache/huggingface/hub/models--MahmoodLab--UNI2-h```), which ```timm``` will automatically find when using the commands below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "from timm.data import resolve_data_config\n",
    "from timm.data.transforms_factory import create_transform\n",
    "from huggingface_hub import login\n",
    "\n",
    "# authenticate with your User Access Token (https://huggingface.co/settings/tokens)\n",
    "login()\n",
    "\n",
    "# pretrained=True loads the UNI weights (downloading them on first use);\n",
    "# init_values must be passed so the LayerScale parameters (e.g. block.0.ls1.gamma) load\n",
    "model = timm.create_model(\n",
    "    \"hf-hub:MahmoodLab/UNI2-h\",\n",
    "    pretrained=True,\n",
    "    init_values=1e-5,\n",
    "    dynamic_img_size=True,\n",
    ")\n",
    "\n",
    "# build the image transform from the model's pretrained config\n",
    "data_config = resolve_data_config(model.pretrained_cfg, model=model)\n",
    "transform = create_transform(**data_config)\n",
    "\n",
    "# switch to eval mode and move the model to the selected device\n",
    "model.eval()\n",
    "model.to(device)\n",
    "\n",
    "# show the resulting transform as the cell output\n",
    "transform"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Method 2: You can also download the model weights to a specified checkpoint location in your local directory. The ```timm``` library is still used for defining the ViT-H/14 model architecture. Pretrained weights and image transforms for UNI need to be manually loaded and defined.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "import timm\n",
    "from huggingface_hub import login, hf_hub_download\n",
    "\n",
    "login()  # login with your User Access Token, found at https://huggingface.co/settings/tokens\n",
    "\n",
    "local_dir = \"../assets/ckpts/uni2-h/\"\n",
    "os.makedirs(local_dir, exist_ok=True)  # create directory if it does not exist\n",
    "# NOTE: force_download=True re-fetches the checkpoint every time this cell runs\n",
    "hf_hub_download(\"MahmoodLab/UNI2-h\", filename=\"pytorch_model.bin\", local_dir=local_dir, force_download=True)\n",
    "\n",
    "# architecture overrides applied on top of the base timm ViT definition\n",
    "timm_kwargs = {'img_size': 224,\n",
    "               'patch_size': 14,\n",
    "               'depth': 24,\n",
    "               'num_heads': 24,\n",
    "               'init_values': 1e-5,\n",
    "               'embed_dim': 1536,\n",
    "               'mlp_ratio': 2.66667*2,\n",
    "               'num_classes': 0,\n",
    "               'no_embed_class': True,\n",
    "               'mlp_layer': timm.layers.SwiGLUPacked,\n",
    "               'act_layer': torch.nn.SiLU,\n",
    "               'reg_tokens': 8,\n",
    "               'dynamic_img_size': True\n",
    "              }\n",
    "\n",
    "# timm.create_model requires the base architecture name as its first argument\n",
    "# (calling it with only **timm_kwargs raises a TypeError for the missing model_name).\n",
    "# pretrained=False because the weights come from the local checkpoint loaded below.\n",
    "model = timm.create_model('vit_giant_patch14_224', pretrained=False, **timm_kwargs)\n",
    "\n",
    "# strict=True: fail loudly if the checkpoint keys and the model keys do not match exactly\n",
    "model.load_state_dict(torch.load(os.path.join(local_dir, \"pytorch_model.bin\"), map_location=\"cpu\"), strict=True)\n",
    "model.eval()\n",
    "model.to(device)\n",
    "\n",
    "# image transform defined manually for this method\n",
    "transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n",
    "    ]\n",
    ")\n",
    "transform"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function `get_encoder` performs the commands above, downloading the checkpoint to the `./assets/ckpts/` relative path of this GitHub repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from uni import get_encoder\n",
    "\n",
    "# one-call alternative to the two methods above: returns the model and its transform\n",
    "model, transform = get_encoder(\n",
    "    enc_name='uni2-h',\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ROI Feature Extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from uni.downstream.extract_patch_features import extract_patch_features_from_dataloader\n",
    "\n",
    "# root folder of the example data\n",
    "dataroot = '../assets/data/tcga_luadlusc'\n",
    "\n",
    "# image-folder datasets for the train/test splits, both using the encoder's transform\n",
    "train_dataset = torchvision.datasets.ImageFolder(root=j_(dataroot, 'train'), transform=transform)\n",
    "test_dataset = torchvision.datasets.ImageFolder(root=j_(dataroot, 'test'), transform=transform)\n",
    "\n",
    "# data loaders over each split; shuffle=False keeps samples in dataset order\n",
    "train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=4, shuffle=False)\n",
    "test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=4, shuffle=False)\n",
    "\n",
    "# extract patch features for both splits (each result is a dict with 'embeddings' and 'labels')\n",
    "train_features = extract_patch_features_from_dataloader(model, train_dataloader)\n",
    "test_features = extract_patch_features_from_dataloader(model, test_dataloader)\n",
    "\n",
    "# convert to torch tensors; labels are additionally cast to long\n",
    "train_feats = torch.Tensor(train_features['embeddings'])\n",
    "test_feats = torch.Tensor(test_features['embeddings'])\n",
    "train_labels = torch.Tensor(train_features['labels']).long()\n",
    "test_labels = torch.Tensor(test_features['labels']).long()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ROI Linear Probe Evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from uni.downstream.eval_patch_features.linear_probe import eval_linear_probe\n",
    "\n",
    "# linear probe on the extracted train features, scored on the test features;\n",
    "# no validation split is passed (valid_feats / valid_labels are None)\n",
    "linprobe_eval_metrics, linprobe_dump = eval_linear_probe(\n",
    "    train_feats=train_feats,\n",
    "    train_labels=train_labels,\n",
    "    valid_feats=None,\n",
    "    valid_labels=None,\n",
    "    test_feats=test_feats,\n",
    "    test_labels=test_labels,\n",
    "    max_iter=1000,\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "# report the linear-probe metrics\n",
    "print_metrics(linprobe_eval_metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ROI KNN and ProtoNet evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from uni.downstream.eval_patch_features.fewshot import eval_knn\n",
    "\n",
    "# a single call returns metrics and dumps for both the KNN and the ProtoNet evaluation\n",
    "knn_eval_metrics, knn_dump, proto_eval_metrics, proto_dump = eval_knn(\n",
    "    train_feats=train_feats,\n",
    "    train_labels=train_labels,\n",
    "    test_feats=test_feats,\n",
    "    test_labels=test_labels,\n",
    "    center_feats=True,\n",
    "    normalize_feats=True,\n",
    "    n_neighbors=20,\n",
    ")\n",
    "\n",
    "# report KNN metrics first, then ProtoNet metrics\n",
    "print_metrics(knn_eval_metrics)\n",
    "print_metrics(proto_eval_metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ROI Few-Shot Evaluation (based on ProtoNet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from uni.downstream.eval_patch_features.fewshot import eval_fewshot\n",
    "\n",
    "# every test sample is used as a query in each episode\n",
    "n_test_samples = test_feats.shape[0]\n",
    "\n",
    "fewshot_episodes, fewshot_dump = eval_fewshot(\n",
    "    train_feats=train_feats,\n",
    "    train_labels=train_labels,\n",
    "    test_feats=test_feats,\n",
    "    test_labels=test_labels,\n",
    "    n_iter=500,  # draw 500 few-shot episodes\n",
    "    n_way=2,  # use all class examples\n",
    "    n_shot=4,  # 4 examples per class (as we don't have that many)\n",
    "    n_query=n_test_samples,  # evaluate on all test samples\n",
    "    center_feats=True,\n",
    "    normalize_feats=True,\n",
    "    average_feats=True,\n",
    ")\n",
    "\n",
    "# per-episode results: how well we did picking 4 random examples per class\n",
    "display(fewshot_episodes)\n",
    "\n",
    "# summary over all episodes\n",
    "display(fewshot_dump)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A Closer Look at ProtoNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can use ProtoNet in a sklearn-like API as well for fitting and predicting models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from uni.downstream.eval_patch_features.protonet import ProtoNet\n",
    "\n",
    "# fit the prototype classifier on the train features\n",
    "proto_clf = ProtoNet(metric='L2', center_feats=True, normalize_feats=True)\n",
    "proto_clf.fit(train_feats, train_labels)\n",
    "prototype_shape = proto_clf.prototype_embeddings.shape\n",
    "print('What our prototypes look like', prototype_shape)\n",
    "\n",
    "# predict on the test features and score the predictions against the test labels\n",
    "test_pred = proto_clf.predict(test_feats)\n",
    "get_eval_metrics(test_labels, test_pred, get_report=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using `proto_clf._get_topk_queries_inds`, we use the test samples as the query set, and get the top-k queries to each prototype, effectively doing ROI retrieval."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# top-5 test samples for each prototype\n",
    "dist, topk_inds = proto_clf._get_topk_queries_inds(test_feats, topk=5)\n",
    "print('label2idx correspondenes', train_dataset.class_to_idx)\n",
    "\n",
    "# table of test image paths and labels, used to look up the retrieved images\n",
    "test_imgs_df = pd.DataFrame(test_dataset.imgs, columns=['path', 'label'])\n",
    "\n",
    "# row 0 of topk_inds is treated as the LUAD prototype (compare with the mapping printed above)\n",
    "print('Top-k LUAD-like test samples to LUAD prototype')\n",
    "luad_topk_inds = topk_inds[0]\n",
    "luad_topk_paths = test_imgs_df['path'][luad_topk_inds]\n",
    "luad_topk_imgs = concat_images([Image.open(img_fpath) for img_fpath in luad_topk_paths], scale=0.5)\n",
    "display(luad_topk_imgs)\n",
    "\n",
    "# row 1 of topk_inds is treated as the LUSC prototype\n",
    "print('Top-k LUSC-like test samples to LUSC prototype')\n",
    "lusc_topk_inds = topk_inds[1]\n",
    "lusc_topk_paths = test_imgs_df['path'][lusc_topk_inds]\n",
    "lusc_topk_imgs = concat_images([Image.open(img_fpath) for img_fpath in lusc_topk_paths], scale=0.5)\n",
    "display(lusc_topk_imgs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using `proto_clf._get_topk_prototypes_inds`, we can instead use the prototypes as the query set, and get the top-k queries to each test sample. With k set to # of prototypes / labels, we are essentially doing ROI classification (assigning label of the nearest prototype to the test sample)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for every test sample, rank the prototypes (topk=2 covers both prototypes here)\n",
    "dist, topk_inds = proto_clf._get_topk_prototypes_inds(test_feats, topk=2)\n",
    "print(\"The top-2 closest prototypes to each test sample, with closer prototypes first (left hand side)\")\n",
    "display(topk_inds)\n",
    "\n",
    "# column 0 holds the closest prototype per test sample, i.e. the predicted label\n",
    "print('Labels of the top-1 closest prototypes')\n",
    "pred_test = topk_inds[:, 0]\n",
    "# show the predictions; previously the heading was printed but nothing was displayed\n",
    "display(pred_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "UNI",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
